#!/home/unitree/miniconda3/envs/x5/bin/python3
# -- coding: UTF-8


import torch
import numpy as np
import os
import pickle
import argparse
from einops import rearrange
import h5py

from utils import compute_dict_mean, set_seed, detach_dict # helper functions
from policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
import collections
from collections import deque
import cv2
import time
import threading
import math
import threading


import sys
sys.path.append("./")

# Cameras expected under /observations/images in the HDF5 episode files.
task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']}

# Shared state between the inference worker thread and the main loop.
inference_thread = None            # worker Thread handle; reset to None after each join
inference_lock = threading.Lock()  # guards inference_actions / inference_timestep
inference_actions = None           # latest action chunk produced by inference_process
inference_timestep = None          # timestep the latest chunk corresponds to

def read_single_episode(episode_path, index=0, camera_names=None, use_depth_image=False, use_robot_base=False):
    """Read a single timestep from an HDF5 episode file.

    Args:
        episode_path: Path to the HDF5 file.
        index: Index of the timestep to read (default: 0).
        camera_names: List of camera names to read; None is treated as no
            cameras (previously this raised a TypeError when iterated).
        use_depth_image: Whether to also read depth images.
        use_robot_base: Whether to append robot-base data to qpos/actions.

    Returns:
        dict with keys 'is_sim', 'actions', 'qpos', 'images' and, when
        use_depth_image is True, 'images_depth'.
    """
    if camera_names is None:
        # Guard against the default: iterating None would raise TypeError.
        camera_names = []
    with h5py.File(episode_path, 'r') as root:
        is_sim = root.attrs['sim']
        is_compress = root.attrs['compress']

        # Read the action and joint-state vectors for this timestep.
        actions = root['/action'][index]
        qpos = root['/observations/qpos'][index]

        if use_robot_base:
            # Base data is appended to both the state and the action vectors.
            base_action = root['/base_action'][index]
            qpos = np.concatenate((qpos, base_action), axis=0)
            actions = np.concatenate((actions, base_action), axis=0)

        # Read RGB (and optionally depth) images for each requested camera.
        image_dict = {}
        image_depth_dict = {}
        for cam_name in camera_names:
            if is_compress:
                # Stored as an encoded byte buffer; flag 1 = decode as color.
                decoded_image = root[f'/observations/images/{cam_name}'][index]
                image_dict[cam_name] = cv2.imdecode(decoded_image, 1)
            else:
                image_dict[cam_name] = root[f'/observations/images/{cam_name}'][index]

            if use_depth_image:
                image_depth_dict[cam_name] = root[f'/observations/images_depth/{cam_name}'][index]

        episode = {
            'is_sim': is_sim,
            'actions': actions,
            'qpos': qpos,
            'images': image_dict
        }

        if use_depth_image:
            episode['images_depth'] = image_depth_dict

        return episode


def actions_interpolation(args, pre_action, actions, stats):
    """Densify a predicted action chunk by linear interpolation.

    Starting at the chunk index most different from ``pre_action``, inserts
    intermediate points so that no joint moves more than its configured
    ``args.arm_steps_length`` per step, then pads/crops the result back to
    ``args.chunk_size`` and re-normalizes it.

    Args:
        args: CLI args; uses ``arm_steps_length`` and ``chunk_size``.
        pre_action: Previous denormalized action vector (1-D).
        actions: Normalized action chunk, shape (1, chunk_size, dim).
        stats: Dataset statistics with 'qpos_mean'/'qpos_std'.

    Returns:
        Normalized, interpolated chunk of shape (1, chunk_size, dim).
    """
    # Per-joint maximum step sizes, duplicated for the left and right arms.
    steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
    # NOTE(review): actions are (de)normalized with qpos statistics here
    # rather than 'action_mean'/'action_std' — confirm this is intentional.
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['qpos_std'] + stats['qpos_mean']
    result = [pre_action]
    post_action = post_process(actions[0])
    print("actions_interpolation1:", post_action[:, 7:])
    # Find the chunk index whose joints differ most from the previous action;
    # indices 6 and 13 are excluded from the distance computation.
    max_diff_index = 0
    max_diff = -1
    for i in range(post_action.shape[0]):
        diff = 0
        for j in range(pre_action.shape[0]):
            if j == 6 or j == 13:
                continue
            diff += math.fabs(pre_action[j] - post_action[i][j])
        if diff > max_diff:
            max_diff = diff
            max_diff_index = i

    # From that index onward, insert linearly interpolated points so each
    # joint moves at most its step length per timestep.
    for i in range(max_diff_index, post_action.shape[0]):
        step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])])
        inter = np.linspace(result[-1], post_action[i], step+2)
        result.extend(inter[1:])
    # Pad with the last action, then keep exactly chunk_size steps,
    # dropping the seed pre_action at index 0.
    while len(result) < args.chunk_size+1:
        result.append(result[-1])
    result = np.array(result)[1:args.chunk_size+1]
    print("actions_interpolation2:", result.shape, result[:, 7:])
    result = pre_process(result)
    result = result[np.newaxis, :]
    return result


def get_model_config(args):
    """Build the inference configuration for the selected policy class.

    Seeds the RNGs first so repeated runs produce identical random
    sequences, then assembles the policy-specific hyper-parameter dict and
    wraps it in the top-level config consumed by model_inference().
    """
    # Fixed seed => reproducible random state across runs.
    set_seed(1)

    # Hyper-parameters shared by every policy class.
    base_config = {
        'lr': args.lr,
        'lr_backbone': args.lr_backbone,
        'backbone': args.backbone,
        'masks': args.masks,
        'weight_decay': args.weight_decay,
        'dilation': args.dilation,
        'position_embedding': args.position_embedding,
        'loss_function': args.loss_function,
        'camera_names': task_config['camera_names'],
        'use_depth_image': args.use_depth_image,
        'use_robot_base': args.use_robot_base,
    }

    if args.policy_class == 'ACT':
        policy_config = dict(base_config,
                             chunk_size=args.chunk_size,      # query length
                             kl_weight=args.kl_weight,        # KL-divergence loss weight
                             hidden_dim=args.hidden_dim,      # transformer hidden width
                             dim_feedforward=args.dim_feedforward,
                             enc_layers=args.enc_layers,
                             dec_layers=args.dec_layers,
                             nheads=args.nheads,
                             dropout=args.dropout,
                             pre_norm=args.pre_norm)
    elif args.policy_class == 'CNNMLP':
        # CNNMLP predicts a single step at a time.
        policy_config = dict(base_config, chunk_size=1)
    elif args.policy_class == 'Diffusion':
        policy_config = dict(base_config,
                             chunk_size=args.chunk_size,
                             observation_horizon=args.observation_horizon,
                             action_horizon=args.action_horizon,
                             num_inference_timesteps=args.num_inference_timesteps,
                             ema_power=args.ema_power)
    else:
        raise NotImplementedError

    return {
        'ckpt_dir': args.ckpt_dir,
        'ckpt_name': args.ckpt_name,
        'ckpt_stats_name': args.ckpt_stats_name,
        'max_publish_step': args.max_publish_step,
        'state_dim': args.state_dim,
        'policy_class': args.policy_class,
        'policy_config': policy_config,
        'temporal_agg': args.temporal_agg,
        'camera_names': task_config['camera_names'],
    }


def make_policy(policy_class, policy_config):
    """Instantiate the policy object matching ``policy_class``.

    Raises NotImplementedError for unknown class names.
    """
    constructors = {
        'ACT': ACTPolicy,
        'CNNMLP': CNNMLPPolicy,
        'Diffusion': DiffusionPolicy,
    }
    if policy_class not in constructors:
        raise NotImplementedError
    return constructors[policy_class](policy_config)


def get_image(observation, camera_names):
    """Stack the requested camera images into a normalized CUDA tensor.

    Converts each HWC image to CHW, stacks them along a new camera axis,
    scales to [0, 1], and adds a leading batch dimension.
    """
    # (h, w, c) -> (c, h, w) for every requested camera.
    frames = [np.transpose(observation['images'][name], (2, 0, 1)) for name in camera_names]
    stacked = np.stack(frames, axis=0)
    return torch.from_numpy(stacked / 255.0).float().cuda().unsqueeze(0)


def get_depth_image(observation, camera_names):
    """Stack the requested depth images into a normalized CUDA tensor.

    Depth frames are stacked as-is (no channel reordering), scaled by 1/255,
    and given a leading batch dimension.
    """
    frames = [observation['images_depth'][name] for name in camera_names]
    stacked = np.stack(frames, axis=0)
    return torch.from_numpy(stacked / 255.0).float().cuda().unsqueeze(0)


def inference_process(args, config, episode_path, policy, stats, t, pre_action):
    """Read one frame from the episode file and run a single policy inference.

    Publishes the (optionally interpolated) action chunk into the
    module-level ``inference_actions`` / ``inference_timestep`` globals under
    ``inference_lock`` so the main loop can consume it.

    Args:
        args: Parsed command-line arguments.
        config: Config dict produced by get_model_config().
        episode_path: Path to the HDF5 episode file.
        policy: Loaded policy (on CUDA, in eval mode).
        stats: Normalization statistics loaded from the dataset pickle.
        t: Timestep index to read and run inference on.
        pre_action: Previous denormalized action, or None on the first step.
    """
    global inference_lock
    global inference_actions
    global inference_timestep
    # Normalize joint states with the dataset statistics before inference.
    pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']

    print("Reading frame from episode:", t)
    # Read a single frame of observations from the HDF5 file.
    obs = read_single_episode(
        episode_path,
        index=t,
        camera_names=config['camera_names'],
        use_depth_image=args.use_depth_image,
        use_robot_base=args.use_robot_base
    )

    if obs is None:
        print("Failed to read frame")
        return False

    # Log the ground-truth action stored at this timestep.
    print(f"\nInput action at t={t}:", obs['actions'])

    # Normalize qpos and move it to the GPU with a batch dimension.
    qpos = pre_pos_process(obs['qpos'])
    qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)

    # Assemble the current RGB (and optional depth) camera images.
    curr_image = get_image(obs, config['camera_names'])
    curr_depth_image = None
    if args.use_depth_image:
        curr_depth_image = get_depth_image(obs, config['camera_names'])

    start_time = time.time()
    all_actions = policy(curr_image, curr_depth_image, qpos)
    end_time = time.time()
    print("Model inference time:", end_time - start_time)

    # Log the first predicted action, denormalized for readability.
    # NOTE(review): denormalization uses qpos statistics rather than action
    # statistics — confirm this matches how the policy was trained.
    normalized_action = all_actions[0][0].cpu().detach().numpy()
    denormalized_action = normalized_action * stats['qpos_std'] + stats['qpos_mean']
    print(f"Model output action at t={t}:")
    print("- Denormalized:", denormalized_action)

    # Context manager guarantees the lock is released even if
    # actions_interpolation raises (the bare acquire/release did not).
    with inference_lock:
        inference_actions = all_actions.cpu().detach().numpy()
        if pre_action is None:
            pre_action = obs['qpos']
        if args.use_actions_interpolation:
            inference_actions = actions_interpolation(args, pre_action, inference_actions, stats)
        inference_timestep = t


def model_inference(args, config, episode_path, save_episode=True):
    """
    Load the policy checkpoint and statistics, then run step-by-step
    inference over the episode.

    NOTE(review): ``save_episode`` is accepted but never used here.
    """
    global inference_lock
    global inference_actions
    global inference_timestep
    global inference_thread
    
    set_seed(1000)

    # Create the policy and load its checkpoint weights.
    policy = make_policy(config['policy_class'], config['policy_config'])
    ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name'])
    state_dict = torch.load(ckpt_path,weights_only=True)
    # Drop heads that are not needed (or not present) at inference time.
    new_state_dict = {}
    for key, value in state_dict.items():
        if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]:
            continue
        if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]:
            continue
        new_state_dict[key] = value
    loading_status = policy.deserialize(new_state_dict)
    if not loading_status:
        print("ckpt path not exist")
        return False

    # Move the model to CUDA and switch it to evaluation mode.
    policy.cuda()
    policy.eval()

    # Load the dataset normalization statistics.
    stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name'])
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # Normalize/denormalize helpers based on the qpos statistics.
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['qpos_std'] + stats['qpos_mean']

    max_publish_step = config['max_publish_step']
    chunk_size = config['policy_config']['chunk_size']

    print("Starting inference...")
    with torch.inference_mode():
        t = 0
        max_t = 0
        if config['temporal_agg']:
            # Buffer of every chunk prediction, indexed by (inference step,
            # target timestep), used for exponential temporal ensembling.
            all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']])
        while t < max_publish_step:
            if config['policy_class'] == "ACT":
                # Re-run inference once the current chunk is exhausted.
                if t >= max_t:
                    # 'action' only exists after the first loop iteration.
                    pre_action = action if 'action' in locals() else None
                    print("Inferencing frame:", t)
                    # NOTE(review): the thread is joined immediately, so this
                    # effectively runs synchronously despite using threading.
                    inference_thread = threading.Thread(
                        target=inference_process,
                        args=(args, config, episode_path, policy, stats, t, pre_action)
                    )
                    inference_thread.start()
                    inference_thread.join()
                    
                    inference_lock.acquire()
                    if inference_actions is not None:
                        inference_thread = None
                        all_actions = inference_actions
                        inference_actions = None
                        max_t = t + args.pos_lookahead_step
                        if config['temporal_agg']:
                            all_time_actions[[t], t:t + chunk_size] = all_actions
                    inference_lock.release()
                    
                if config['temporal_agg']:
                    # Exponentially weighted average over all stored chunks
                    # that predicted an action for the current timestep.
                    actions_for_curr_step = all_time_actions[:, t]
                    actions_populated = np.all(actions_for_curr_step != 0, axis=1)
                    actions_for_curr_step = actions_for_curr_step[actions_populated]
                    k = 0.01
                    exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                    exp_weights = exp_weights / exp_weights.sum()
                    exp_weights = exp_weights[:, np.newaxis]
                    raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True)
                else:
                    if args.pos_lookahead_step != 0:
                        raw_action = all_actions[:, t % args.pos_lookahead_step]
                    else:
                        raw_action = all_actions[:, t % chunk_size]
                # Denormalize and split into per-arm commands.
                action = post_process(raw_action[0])
                left_action = action[:7]  # first 7 joint dimensions
                right_action = action[7:14]
                t += 1


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
    parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False)
    parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False)
    parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False)
    parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False)
    parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False)
    parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False)
    parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False)
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False)
    parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False)
    parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False)
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False)
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features", required=False)
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False)
    parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False)
    parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False)
    parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False)

    parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False)
    parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False)
    parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False)
    parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False)
    parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False)
    parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False)
    parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False)
    parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False)
    parser.add_argument('--pre_norm', action='store_true', required=False)

    parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
                        default='/camera_t/color/image_raw', required=False)
    parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
                        default='/camera_l/color/image_raw', required=False)
    parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
                        default='/camera_r/color/image_raw', required=False)
    
    parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
                        default='/camera_t/depth/image_raw', required=False)
    parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
                        default='/camera_l/depth/image_raw', required=False)
    parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
                        default='/camera_r/depth/image_raw', required=False)
    
    parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
                        default='/master/joint_left', required=False)
    parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
                        default='/master/joint_right', required=False)
    parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
                        default='/puppet/joint_left', required=False)
    parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
                        default='/puppet/joint_right', required=False)
    
    parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
                        default='/odom', required=False)
    parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_topic',
                        default='/cmd_vel', required=False)
    parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
                        default=False, required=False)
    parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate',
                        default=40, required=False)
    parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step',
                        default=0, required=False)
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size',
                        default=32, required=False)
    parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length',
                        default=[0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.2], required=False)

    parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation',
                        default=False, required=False)
    parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image',
                        default=False, required=False)
    parser.add_argument('--episode_path', type=str, required=True,
                      help='Path to the HDF5 episode file')

    # for Diffusion
    parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False)
    parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False)
    parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False)
    parser.add_argument('--ema_power', action='store', type=int, help='ema_power', default=0.75, required=False)

    args = parser.parse_args()
    return args


def main():
    """CLI entry point: parse arguments, build the config, run inference."""
    args = get_arguments()
    # The robot base contributes two extra state dimensions when enabled.
    if args.use_robot_base:
        args.state_dim = args.state_dim + 2
    config = get_model_config(args)
    model_inference(args, config, args.episode_path, save_episode=True)


# Script entry point.
if __name__ == '__main__':
    main()

# Example invocation:
# python evaluate.py --episode_path /gemini/data-1/ACT/move_lippie/episode_13.hdf5 --ckpt_dir /gemini/data-1/ACT/ckpt/lippie_data/ --max_publish_step 200 --chunk_size 30 | tee evl.log